{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Chapter 10 - Relational Deep Reinforcement Learning\n",
    "### Deep Reinforcement Learning *in Action*"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 10.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n",
      "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to MNIST/MNIST/raw/train-images-idx3-ubyte.gz\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████| 9912422/9912422 [00:08<00:00, 1221139.92it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting MNIST/MNIST/raw/train-images-idx3-ubyte.gz to MNIST/MNIST/raw\n",
      "\n",
      "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n",
      "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to MNIST/MNIST/raw/train-labels-idx1-ubyte.gz\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████| 28881/28881 [00:00<00:00, 12552921.64it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting MNIST/MNIST/raw/train-labels-idx1-ubyte.gz to MNIST/MNIST/raw\n",
      "\n",
      "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n",
      "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to MNIST/MNIST/raw/t10k-images-idx3-ubyte.gz\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████| 1648877/1648877 [00:01<00:00, 1283586.12it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting MNIST/MNIST/raw/t10k-images-idx3-ubyte.gz to MNIST/MNIST/raw\n",
      "\n",
      "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n",
      "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to MNIST/MNIST/raw/t10k-labels-idx1-ubyte.gz\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████| 4542/4542 [00:00<00:00, 1829846.20it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting MNIST/MNIST/raw/t10k-labels-idx1-ubyte.gz to MNIST/MNIST/raw\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "from matplotlib import pyplot as plt\n",
    "import torch\n",
    "from torch import nn\n",
    "import torchvision as TV\n",
    "\n",
    "mnist_data = TV.datasets.MNIST(\"MNIST/\", train=True, transform=None,\\\n",
    "                                        target_transform=None, download=True) #A\n",
    "mnist_test = TV.datasets.MNIST(\"MNIST/\", train=False, transform=None,\\\n",
    "                                        target_transform=None, download=True) #B\n",
    "\n",
    "\n",
    "def add_spots(x,m=20,std=5,val=1): #C\n",
    "    mask = torch.zeros(x.shape)\n",
    "    N = int(m + std * np.abs(np.random.randn()))\n",
    "    ids = np.random.randint(np.prod(x.shape),size=N)\n",
    "    mask.view(-1)[ids] = val\n",
    "    return torch.clamp(x + mask,0,1)\n",
    "\n",
    "def prepare_images(xt,maxtrans=6,rot=5,noise=10): #D\n",
    "    out = torch.zeros(xt.shape)\n",
    "    for i in range(xt.shape[0]):\n",
    "        img = xt[i].unsqueeze(dim=0)\n",
    "        img = TV.transforms.functional.to_pil_image(img)\n",
    "        rand_rot = np.random.randint(-1*rot,rot,1) if rot < 0 else 0\n",
    "        xtrans,ytrans = np.random.randint(-maxtrans,maxtrans,2)\n",
    "        #print(rand_rot[0])\n",
    "        img = TV.transforms.functional.affine(img, rand_rot, (xtrans,ytrans),1,0)\n",
    "        img = TV.transforms.functional.to_tensor(img).squeeze()\n",
    "        if noise > 0:\n",
    "            img = add_spots(img,m=noise)\n",
    "        maxval = img.view(-1).max()\n",
    "        if maxval > 0:\n",
    "            img = img.float() / maxval\n",
    "        else:\n",
    "            img = img.float()\n",
    "        out[i] = img\n",
    "    return out"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 10.2/10.3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RelationalModule(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super(RelationalModule, self).__init__()\n",
    "        self.ch_in = 1\n",
    "        self.conv1_ch = 16 #A\n",
    "        self.conv2_ch = 20\n",
    "        self.conv3_ch = 24\n",
    "        self.conv4_ch = 30\n",
    "        self.H = 28 #B\n",
    "        self.W = 28\n",
    "        self.node_size = 36 #C\n",
    "        self.lin_hid = 100\n",
    "        self.out_dim = 10\n",
    "        self.sp_coord_dim = 2\n",
    "        self.N = int(16**2) #D\n",
    "\n",
    "        self.conv1 = nn.Conv2d(self.ch_in,self.conv1_ch,kernel_size=(4,4))\n",
    "        self.conv2 = nn.Conv2d(self.conv1_ch,self.conv2_ch,kernel_size=(4,4))\n",
    "        self.conv3 = nn.Conv2d(self.conv2_ch,self.conv3_ch,kernel_size=(4,4))\n",
    "        self.conv4 = nn.Conv2d(self.conv3_ch,self.conv4_ch,kernel_size=(4,4))\n",
    "        \n",
    "        self.proj_shape = (self.conv4_ch+self.sp_coord_dim,self.node_size) #E\n",
    "        self.k_proj = nn.Linear(*self.proj_shape)\n",
    "        self.q_proj = nn.Linear(*self.proj_shape)\n",
    "        self.v_proj = nn.Linear(*self.proj_shape)\n",
    "        \n",
    "        self.norm_shape = (self.N,self.node_size)\n",
    "        self.k_norm = nn.LayerNorm(self.norm_shape, elementwise_affine=True) #F\n",
    "        self.q_norm = nn.LayerNorm(self.norm_shape, elementwise_affine=True)\n",
    "        self.v_norm = nn.LayerNorm(self.norm_shape, elementwise_affine=True)\n",
    "        \n",
    "        self.linear1 = nn.Linear(self.node_size, self.node_size)\n",
    "        self.norm1 = nn.LayerNorm([self.N,self.node_size], elementwise_affine=False)\n",
    "        self.linear2 = nn.Linear(self.node_size, self.out_dim)\n",
    "\n",
    "    def forward(self,x):\n",
    "            N, Cin, H, W = x.shape\n",
    "            x = self.conv1(x) \n",
    "            x = torch.relu(x)\n",
    "            x = self.conv2(x) \n",
    "            x = x.squeeze() \n",
    "            x = torch.relu(x) \n",
    "            x = self.conv3(x)\n",
    "            x = torch.relu(x)\n",
    "            x = self.conv4(x)\n",
    "            x = torch.relu(x)\n",
    "\n",
    "            _,_,cH,cW = x.shape\n",
    "            xcoords = torch.arange(cW).repeat(cH,1).float() / cW #G\n",
    "            ycoords = torch.arange(cH).repeat(cW,1).transpose(1,0).float() / cH\n",
    "            spatial_coords = torch.stack([xcoords,ycoords],dim=0)\n",
    "            spatial_coords = spatial_coords.unsqueeze(dim=0)\n",
    "            spatial_coords = spatial_coords.repeat(N,1,1,1) \n",
    "            x = torch.cat([x,spatial_coords],dim=1)\n",
    "            x = x.permute(0,2,3,1)\n",
    "            x = x.flatten(1,2)\n",
    "\n",
    "            K = self.k_proj(x) #H\n",
    "            K = self.k_norm(K) \n",
    "\n",
    "            Q = self.q_proj(x)\n",
    "            Q = self.q_norm(Q) \n",
    "\n",
    "            V = self.v_proj(x)\n",
    "            V = self.v_norm(V) \n",
    "            A = torch.einsum('bfe,bge->bfg',Q,K) #I\n",
    "            A = A / np.sqrt(self.node_size)\n",
    "            A = torch.nn.functional.softmax(A,dim=2) \n",
    "            with torch.no_grad():\n",
    "                self.att_map = A.clone()\n",
    "            E = torch.einsum('bfc,bcd->bfd',A,V) #J\n",
    "            E = self.linear1(E)\n",
    "            E = torch.relu(E)\n",
    "            E = self.norm1(E)  \n",
    "            E = E.max(dim=1)[0]\n",
    "            y = self.linear2(E)  \n",
    "            y = torch.nn.functional.log_softmax(y,dim=1)\n",
    "            return y"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 10.4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/don/git/DeepReinforcementLearningInAction/venv/lib/python3.10/site-packages/torchvision/datasets/mnist.py:75: UserWarning: train_data has been renamed data\n",
      "  warnings.warn(\"train_data has been renamed data\")\n",
      "/home/don/git/DeepReinforcementLearningInAction/venv/lib/python3.10/site-packages/torchvision/datasets/mnist.py:65: UserWarning: train_labels has been renamed targets\n",
      "  warnings.warn(\"train_labels has been renamed targets\")\n"
     ]
    }
   ],
   "source": [
    "agent = RelationalModule() #A\n",
    "epochs = 1000\n",
    "batch_size=300\n",
    "lr = 1e-3\n",
    "opt = torch.optim.Adam(params=agent.parameters(),lr=lr)\n",
    "lossfn = nn.NLLLoss()\n",
    "for i in range(epochs):\n",
    "    opt.zero_grad()\n",
    "    batch_ids = np.random.randint(0,60000,size=batch_size) #B\n",
    "    xt = mnist_data.train_data[batch_ids].detach()\n",
    "    xt = prepare_images(xt,rot=30).unsqueeze(dim=1) #C\n",
    "    yt = mnist_data.train_labels[batch_ids].detach()\n",
    "    pred = agent(xt)\n",
    "    pred_labels = torch.argmax(pred,dim=1) #D\n",
    "    acc_ = 100.0 * (pred_labels == yt).sum() / batch_size #E\n",
    "    correct = torch.zeros(batch_size,10)\n",
    "    rows = torch.arange(batch_size).long()\n",
    "    correct[[rows,yt.detach().long()]] = 1.\n",
    "    loss = lossfn(pred,yt)\n",
    "    loss.backward()\n",
    "    opt.step()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 10.5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/don/git/DeepReinforcementLearningInAction/venv/lib/python3.10/site-packages/torchvision/datasets/mnist.py:80: UserWarning: test_data has been renamed data\n",
      "  warnings.warn(\"test_data has been renamed data\")\n",
      "/home/don/git/DeepReinforcementLearningInAction/venv/lib/python3.10/site-packages/torchvision/datasets/mnist.py:70: UserWarning: test_labels has been renamed targets\n",
      "  warnings.warn(\"test_labels has been renamed targets\")\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.9600)\n"
     ]
    }
   ],
   "source": [
    "def test_acc(model,batch_size=500):\n",
    "    acc = 0.\n",
    "    batch_ids = np.random.randint(0,10000,size=batch_size)\n",
    "    xt = mnist_test.test_data[batch_ids].detach()\n",
    "    xt = prepare_images(xt,maxtrans=6,rot=30,noise=10).unsqueeze(dim=1)\n",
    "    yt = mnist_test.test_labels[batch_ids].detach()\n",
    "    preds = model(xt)\n",
    "    pred_ind = torch.argmax(preds.detach(),dim=1)\n",
    "    acc = (pred_ind == yt).sum().float() / batch_size\n",
    "    return acc, xt, yt\n",
    "\n",
    "acc2, xt2, yt2 = test_acc(agent)\n",
    "print(acc2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7feb0e0cfb80>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAgBklEQVR4nO3df3BU9f3v8ddmN9lETFYSJclKIqmXEQXEHwijeFsYMjIZRJmOWh3EXLzT1jYIGIcCbYOtvyK2tRHlC+LcCp0Rf/whaLmjDkUEvfI7YuW25cdXilG+IdpqFoJZkt1z//Cy32+EkATOJ+9sfD5mzh979uR13rPZzStnc3I24HmeJwAAelmG9QAAgG8nCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmQtYDfFMymdShQ4eUm5urQCBgPQ4AoIc8z9ORI0cUjUaVkdH5cU6fK6BDhw6ppKTEegwAwFlqaGjQ4MGDO72/zxVQbm6uJGl89H8qlJHle74XO+J7ZtrLzHQaHwj7/308wcvJdpd9TthZttL16L496Sw60Bp3l51IOMuWJLUedxbtJRw+5ufmOMltTx7X2x8vT/0870yfK6ATb7uFMrIUyvD/B4AXcPckT1sZjgvIwffxBC+YntlpW0Cewx+GQWfRCniOC8jhX9M9l4+5w9empC7/jMJJCAAAExQQAMAEBQQAMEEBAQBMOCugJUuWaMiQIcrOztbYsWO1bds2V7sCAKQhJwX00ksvqbq6Wg888IDq6+s1atQoTZo0SU1NTS52BwBIQ04K6IknntAPf/hDzZgxQ5dddpmWLVumc845R3/4wx9c7A4AkIZ8L6Djx49r586dKi8v/8+dZGSovLxcmzdvPmn7eDyuWCzWYQEA9H++F9Dnn3+uRCKhwsLCDusLCwvV2Nh40va1tbWKRCKphcvwAMC3g/lZcAsWLFBzc3NqaWhosB4JANALfL8Uz/nnn69gMKjDhw93WH/48GEVFRWdtH04HFY47PZyEACAvsf3I6CsrCxdffXVWr9+fWpdMpnU+vXrde211/q9OwBAmnJyMdLq6mpVVlZq9OjRGjNmjOrq6tTS0qIZM2a42B0AIA05KaAf/OAH+uyzz7Rw4UI1Njbqiiuu0BtvvHHSiQkAgG8vZx/HMHPmTM2cOdNVPAAgzZmfBQcA+HaigAAAJiggAIAJCggAYMLZSQhnyzt6VF7guO+5yXjc98ze0NVnq5+VRMJdtiQdb3OXfewrZ9EOH3Gnj7nn8vuZ4e53Vi+ZdJftcG5Jbl9DDh+XQNDR45Ls3s9ujoAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAICJkPUAnQpkfL34HRsI+J6ZkuGwz9N1bkleIuEw3HOX3dbmLNpLJJ1ly3OYna4c/CzpLYHssLNsL3eAm9xESDrc9Xbp+10BAKQ1CggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmPC9gGpra3XNNdcoNzdXgwYN0tSpU7Vnzx6/dwMASHO+F9DGjRtVVVWlLVu2aN26dWpra9MNN9yglpYWv3cFAEhjvl8J4Y033uhwe8WKFRo0aJB27typ7373u37vDgCQppxfiqe5uVmSlJ+ff8r74/G44vF46nYsFnM9EgCgD3B6EkIymdScOXM0btw4jRgx4pTb1NbWKhKJpJaSkhKXIwEA+ginBVRVVaXdu3frxRdf7HSbBQsWqLm5ObU0NDS4HAkA0Ec4ewtu5syZWrt2rTZt2qTBgwd3ul04HFY47O5qrwCAvsn3AvI8T/fee69Wr16tt99+W2VlZX7vAgDQD/heQFVVVVq1apVeffVV5ebmqrGxUZIUiUSUk5Pj9+4AAGnK978BLV26
VM3NzRo/fryKi4tTy0svveT3rgAAaczJW3AAAHSFa8EBAExQQAAAExQQAMAEBQQAMOH8WnB9jdfebj3CmQk4/F0hI+AuW1IgGHQX3tbmLNrpCTUuH/Nkej5XXD5PAuc4/hcQl8/xpLvnYbxkoJPc9vZW6d+73o4jIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYCJkPQC6yUu6y066/T0kEHb4NMvJdhYdcJYsqb3dWXQgM9NZtpd3rrPs9gJ32Z9MHOAsW5JK34g5yw5+7jD7KzfPQ6+bz2+OgAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGDCeQE99thjCgQCmjNnjutdAQDSiNMC2r59u5555hldfvnlLncDAEhDzgro6NGjmjZtmp599lkNHDjQ1W4AAGnKWQFVVVVp8uTJKi8vd7ULAEAac3KRrhdffFH19fXavn17l9vG43HF4/HU7VjM3XWPAAB9h+9HQA0NDZo9e7aef/55ZWd3faHI2tpaRSKR1FJSUuL3SACAPsj3Atq5c6eampp01VVXKRQKKRQKaePGjVq8eLFCoZASiUSH7RcsWKDm5ubU0tDQ4PdIAIA+yPe34CZOnKgPP/yww7oZM2Zo2LBhmjdvnoLBYIf7wuGwwuGw32MAAPo43wsoNzdXI0aM6LBuwIABKigoOGk9AODbiyshAABM9Monor799tu9sRsAQBrhCAgAYIICAgCYoIAAACYoIACACQoIAGCiV86COyMZGV8vkCR5Sc9ZdjAvx1m2JMnhPxoHcrq+3NOZhwecRXs57h6TrwbnOcuOn+fuR0Yy2PU2Zyr6f1rdhUvK+LLFXXhrvOttzlDmoX85yQ0kuzczP+EBACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAICJkPUAnfK8rxdIkgKZDr9VIcdPg/PynEUfG3Kes+yMdnfPv+CxdmfZOXsOO8sORwY4yw60J91lfxV3li1Jak84i/YSDh+XpKPneDdzOQICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACScF9Omnn+rOO+9UQUGBcnJyNHLkSO3YscPFrgAAacr3/0D84osvNG7cOE2YMEGvv/66LrjgAu3bt08DBw70e1cAgDTmewEtWrRIJSUleu6551LrysrK/N4NACDN+f4W3GuvvabRo0fr1ltv1aBBg3TllVfq2Wef7XT7eDyuWCzWYQEA9H++F9BHH32kpUuXaujQoXrzzTf1k5/8RLNmzdLKlStPuX1tba0ikUhqKSkp8XskAEAfFPA8f6/4mZWVpdGjR+u9995LrZs1a5a2b9+uzZs3n7R9PB5XPP6fFwqMxWIqKSnRxPz/oVBGlp+jSZKSzWl6hBUMOovOyD3XWbYkqcDd3/++4mKkJ8k89C9n2UkuRnpqLi9G2upu9kB22EluezKuP3+yVM3NzcrL6/xixL4fARUXF+uyyy7rsO7SSy/Vxx9/fMrtw+Gw8vLyOiwAgP7P9wIaN26c9uzZ02Hd3r17ddFFF/m9KwBAGvO9gO677z5t2bJFjz76qPbv369Vq1Zp+fLlqqqq8ntXAIA05nsBXXPNNVq9erVeeOEFjRgxQg899JDq6uo0bdo0v3cFAEhjTj4K88Ybb9SNN97oIhoA0E9wLTgAgAkKCABgggICAJiggAAAJpychOALz5OS/v8nuucgszdkhALOsgNhN/8NfcI/rznfWXb+ts+cZesLd1fNSH7Z7Cy7ve24s+yMAe6uhOD0lXnOOS7TFQi5u1KJ2t1dNUNy+9rv
CkdAAAATFBAAwAQFBAAwQQEBAExQQAAAExQQAMAEBQQAMEEBAQBMUEAAABMUEADABAUEADBBAQEATFBAAAATFBAAwAQFBAAwQQEBAExQQAAAExQQAMAEBQQAMEEBAQBMUEAAABMUEADARMh6gM4EIrkKZIT9zz12zPfMVHYg4C472//H4oS2iy5wli1JiSx3j0vzqPOdZefuz3aW7fJ56CUSzrKTX7U6y5aXdBYdzMpyli1JyVaHj0vS3eMSSOa4Ce7mzBwBAQBMUEAAABMUEADABAUEADBBAQEATFBAAAATFBAAwITvBZRIJFRTU6OysjLl5OTo4osv1kMPPSTP8/zeFQAgjfn+j6iLFi3S0qVLtXLlSg0fPlw7duzQjBkzFIlENGvWLL93BwBIU74X0Hvvvaebb75ZkydPliQNGTJEL7zwgrZt2+b3rgAAacz3t+Cuu+46rV+/Xnv37pUkffDBB3r33XdVUVFxyu3j8bhisViHBQDQ//l+BDR//nzFYjENGzZMwWBQiURCjzzyiKZNm3bK7Wtra/XrX//a7zEAAH2c70dAL7/8sp5//nmtWrVK9fX1WrlypX77299q5cqVp9x+wYIFam5uTi0NDQ1+jwQA6IN8PwKaO3eu5s+fr9tvv12SNHLkSB08eFC1tbWqrKw8aftwOKxw2N2VngEAfZPvR0DHjh1TRkbH2GAwqKTDS4oDANKP70dAU6ZM0SOPPKLS0lINHz5c77//vp544gndfffdfu8KAJDGfC+gp556SjU1NfrpT3+qpqYmRaNR/fjHP9bChQv93hUAII35XkC5ubmqq6tTXV2d39EAgH6Ea8EBAExQQAAAExQQAMAEBQQAMOH7SQh+aZxYrGBWtu+5xW/6HpmSzM1xlv3RLQOdZQcSAWfZkpT/N3f/AxZsc/cxH43/PeIsO3T15c6yC//3AWfZiX994SzbO37cWXai2e01JjOyMp1lB3JznWXL1f9nJrv3uuQICABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmAhZD9CZd+b9L+Xl+t+PY1p/4nvmCUmHj2awNeAs+/wP251lS1LrwKCz7M+udJe9t/LfnGV/kTjmLHti5v3OsovfCDvLTjZ97iw7EHT3PJGkQMFAZ9mJ8/OcZQc/a3YTnPS6tRlHQAAAExQQAMAEBQQAMEEBAQBMUEAAABMUEADABAUEADDR4wLatGmTpkyZomg0qkAgoDVr1nS43/M8LVy4UMXFxcrJyVF5ebn27dvn17wAgH6ixwXU0tKiUaNGacmSJae8//HHH9fixYu1bNkybd26VQMGDNCkSZPU2tp61sMCAPqPHv/vfkVFhSoqKk55n+d5qqur0y9/+UvdfPPNkqQ//vGPKiws1Jo1a3T77bef3bQAgH7D178BHThwQI2NjSovL0+ti0QiGjt2rDZv3nzKr4nH44rFYh0WAED/52sBNTY2SpIKCws7rC8sLEzd9021tbWKRCKppaSkxM+RAAB9lPlZcAsWLFBzc3NqaWhosB4JANALfC2goqIiSdLhw4c7rD98+HDqvm8Kh8PKy8vrsAAA+j9fC6isrExFRUVav359al0sFtPWrVt17bXX+rkrAECa6/FZcEePHtX+/ftTtw8cOKBdu3YpPz9fpaWlmjNnjh5++GENHTpUZWVlqqmpUTQa1dSpU/2cGwCQ5npcQDt27NCECRNSt6urqyVJlZWVWrFihX72s5+ppaVFP/rRj/Tll1/q+uuv1xtvvKHs7Gz/pgYApL0eF9D48ePleZ1/2l0gENCDDz6oBx988KwGAwD0b+ZnwQEAvp0oIACACQoIAGCCAgIAmOjxSQi95Wf/cZWyjmT6nhtI+h6ZkveP486yL9jxlbPs9nOznGVLUqgl6Cw7
53N3v0P9t8x7nGVnxtzNPaCt85OEzlbD9y90ll3ycsJZdjJyrrNsSfpqcK6z7MyYu58rGee4OTvZSwS6t38newcAoAsUEADABAUEADBBAQEATFBAAAATFBAAwAQFBAAwQQEBAExQQAAAExQQAMAEBQQAMEEBAQBMUEAAABMUEADABAUEADBBAQEATFBAAAATFBAAwAQFBAAwQQEBAExQQAAAExQQAMBEyHqAzrz1+lUKhrN9zx380THfM1OCAXfZ7Uln0ZmHY86yJenIyAucZQcSnrPsxLkOH/OYu9/9gsedRat0ygFn2S27L3SWndHm7nkiSfFI0Fl2ZrOzaHMcAQEATFBAAAATFBAAwAQFBAAwQQEBAExQQAAAExQQAMBEjwto06ZNmjJliqLRqAKBgNasWZO6r62tTfPmzdPIkSM1YMAARaNR3XXXXTp06JCfMwMA+oEeF1BLS4tGjRqlJUuWnHTfsWPHVF9fr5qaGtXX1+uVV17Rnj17dNNNN/kyLACg/+jxlRAqKipUUVFxyvsikYjWrVvXYd3TTz+tMWPG6OOPP1ZpaemZTQkA6HecX4qnublZgUBA55133invj8fjisfjqduxmNvLwgAA+ganJyG0trZq3rx5uuOOO5SXl3fKbWpraxWJRFJLSUmJy5EAAH2EswJqa2vTbbfdJs/ztHTp0k63W7BggZqbm1NLQ0ODq5EAAH2Ik7fgTpTPwYMH9dZbb3V69CNJ4XBY4XDYxRgAgD7M9wI6UT779u3Thg0bVFBQ4PcuAAD9QI8L6OjRo9q/f3/q9oEDB7Rr1y7l5+eruLhYt9xyi+rr67V27VolEgk1NjZKkvLz85WVleXf5ACAtNbjAtqxY4cmTJiQul1dXS1Jqqys1K9+9Su99tprkqQrrriiw9dt2LBB48ePP/NJAQD9So8LaPz48fK8zj9d8HT3AQBwAteCAwCYoIAAACYoIACACQoIAGCCAgIAmHB+MdIzVbruiELBNt9zA//3333PPMFra3eWHcg711m2Egl32ZLyNh9zlu2dl+ss+9K/JZ1l6wt3F90NhILOshP15znLPudzd5fh8vIjzrIlyQt1frWXsxWKtTrLlquzlruZyxEQAMAEBQQAMEEBAQBMUEAAABMUEADABAUEADBBAQEATFBAAAATFBAAwAQFBAAwQQEBAExQQAAAExQQAMAEBQQAMEEBAQBMUEAAABMUEADABAUEADBBAQEATFBAAAATFBAAwAQFBAAwEbIeoDPBT/+pYEaW77mJ48d9zzzBSyScZSf++S9n2c592ewu+z8a3WU7FAi5e+m5fB4GPnf3PPRysp1lBwIBZ9mSFIrkuAs/3uYuOyvTTa7XvWMbjoAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgIkeF9CmTZs0ZcoURaNRBQIBrVmzptNt77nnHgUCAdXV1Z3FiACA/qjHBdTS0qJRo0ZpyZIlp91u9erV2rJli6LR6BkPBwDov3r833AVFRWqqKg47Taffvqp7r33Xr355puaPHnyGQ8HAOi/fP8bUDKZ1PTp0zV37lwNHz7c73gAQD/h+/VAFi1apFAopFmzZnVr+3g8rng8nrodi8X8HgkA0Af5egS0c+dOPfnkk1qxYkW3r71UW1urSCSSWkpKSvwcCQDQR/laQO+8846amppUWlqqUCikUCikgwcP6v7779eQIUNO+TULFixQc3NzamloaPBzJABAH+XrW3DTp09XeXl5h3WTJk3S9OnTNWPGjFN+TTgcVjgc9nMMAEAa6HEBHT16VPv370/dPnDggHbt2qX8/HyVlpaqoKCgw/aZmZkqKirSJZdccvbTAgD6jR4X0I4dOzRhwoTU7erqaklSZWWlVqxY4dtgAID+rccFNH78eHme1+3t//GPf/R0FwCAbwGuBQcAMEEBAQBMUEAAABMUEADABAUEADDh+7Xg/JK4sECBYLbvucEc/zNPiJfmO8v2Mrp3aaMzkch2
+3tIzsdHnGVnfOku22s55izbJe+rr5xlB7KynGUnL3Z3Ga7jF+Q4y5ak0JE2Z9nJc93NnsxxUwHt7d3bjiMgAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgImQ9QDf5HmeJKk9EXeTn3STK0nt7a3Osr2MgLPsRJvb30NcfS8lKcPh99PzjjvLdsnl3AHPWbSSDp8n7e3uXj9f76DNWXRGIuEsO9ne7iS3vf3r7+WJn+edCXhdbdHLPvnkE5WUlFiPAQA4Sw0NDRo8eHCn9/e5Akomkzp06JByc3MVCHT9W0ssFlNJSYkaGhqUl5fXCxP6g7l7V7rOLaXv7Mzdu/rS3J7n6ciRI4pGo8rI6Pwdlj73FlxGRsZpG7MzeXl55g/6mWDu3pWuc0vpOztz966+MnckEulyG05CAACYoIAAACbSvoDC4bAeeOABhcNh61F6hLl7V7rOLaXv7Mzdu9Jx7j53EgIA4Nsh7Y+AAADpiQICAJiggAAAJiggAICJtC6gJUuWaMiQIcrOztbYsWO1bds265G6VFtbq2uuuUa5ubkaNGiQpk6dqj179liP1WOPPfaYAoGA5syZYz1Klz799FPdeeedKigoUE5OjkaOHKkdO3ZYj3VaiURCNTU1KisrU05Oji6++GI99NBDXV5by8KmTZs0ZcoURaNRBQIBrVmzpsP9nudp4cKFKi4uVk5OjsrLy7Vv3z6bYf+L083d1tamefPmaeTIkRowYICi0ajuuusuHTp0yG7g/6+rx/u/uueeexQIBFRXV9dr8/VE2hbQSy+9pOrqaj3wwAOqr6/XqFGjNGnSJDU1NVmPdlobN25UVVWVtmzZonXr1qmtrU033HCDWlparEfrtu3bt+uZZ57R5Zdfbj1Kl7744guNGzdOmZmZev311/XXv/5Vv/vd7zRw4EDr0U5r0aJFWrp0qZ5++mn97W9/06JFi/T444/rqaeesh7tJC0tLRo1apSWLFlyyvsff/xxLV68WMuWLdPWrVs1YMAATZo0Sa2t7i7e2x2nm/vYsWOqr69XTU2N6uvr9corr2jPnj266aabDCbtqKvH+4TVq1dry5YtikajvTTZGfDS1JgxY7yqqqrU7UQi4UWjUa+2ttZwqp5ramryJHkbN260HqVbjhw54g0dOtRbt26d973vfc+bPXu29UinNW/ePO/666+3HqPHJk+e7N19990d1n3/+9/3pk2bZjRR90jyVq9enbqdTCa9oqIi7ze/+U1q3ZdffumFw2HvhRdeMJjw1L4596ls27bNk+QdPHiwd4bqhs7m/uSTT7wLL7zQ2717t3fRRRd5v//973t9tu5IyyOg48ePa+fOnSovL0+ty8jIUHl5uTZv3mw4Wc81NzdLkvLz840n6Z6qqipNnjy5w2Pfl7322msaPXq0br31Vg0aNEhXXnmlnn32WeuxunTddddp/fr12rt3ryTpgw8+0LvvvquKigrjyXrmwIEDamxs7PB8iUQiGjt2bFq+VgOBgM477zzrUU4rmUxq+vTpmjt3roYPH249zmn1uYuRdsfnn3+uRCKhwsLCDusLCwv197//3Wiqnksmk5ozZ47GjRunESNGWI/TpRdffFH19fXavn279Sjd9tFHH2np0qWqrq7Wz3/+c23fvl2zZs1SVlaWKisrrcfr1Pz58xWLxTRs2DAFg0ElEgk98sgjmjZtmvVoPdLY2ChJp3ytnrgvHbS2tmrevHm64447+sSFPk9n0aJFCoVCmjVrlvUoXUrLAuovqqqqtHv3br377rvWo3SpoaFBs2fP1rp165SdnW09Trclk0mNHj1ajz76qCTpyiuv1O7du7Vs2bI+XUAvv/yynn/+ea1atUrDhw/Xrl27NGfOHEWj0T49d3/U1tam2267TZ7naenS
pdbjnNbOnTv15JNPqr6+vlsfZ2MtLd+CO//88xUMBnX48OEO6w8fPqyioiKjqXpm5syZWrt2rTZs2HBGHz/R23bu3KmmpiZdddVVCoVCCoVC2rhxoxYvXqxQKKSEw09tPBvFxcW67LLLOqy79NJL9fHHHxtN1D1z587V/Pnzdfvtt2vkyJGaPn267rvvPtXW1lqP1iMnXo/p+lo9UT4HDx7UunXr+vzRzzvvvKOmpiaVlpamXqcHDx7U/fffryFDhliPd5K0LKCsrCxdffXVWr9+fWpdMpnU+vXrde211xpO1jXP8zRz5kytXr1ab731lsrKyqxH6paJEyfqww8/1K5du1LL6NGjNW3aNO3atUvBYNB6xFMaN27cSae57927VxdddJHRRN1z7Nixkz7IKxgMKplMGk10ZsrKylRUVNThtRqLxbR169Y+/1o9UT779u3Tn//8ZxUUFFiP1KXp06frL3/5S4fXaTQa1dy5c/Xmm29aj3eStH0Lrrq6WpWVlRo9erTGjBmjuro6tbS0aMaMGdajnVZVVZVWrVqlV199Vbm5uan3wSORiHJycoyn61xubu5Jf6caMGCACgoK+vTfr+677z5dd911evTRR3Xbbbdp27ZtWr58uZYvX2492mlNmTJFjzzyiEpLSzV8+HC9//77euKJJ3T33Xdbj3aSo0ePav/+/anbBw4c0K5du5Sfn6/S0lLNmTNHDz/8sIYOHaqysjLV1NQoGo1q6tSpdkPr9HMXFxfrlltuUX19vdauXatEIpF6rebn5ysrK8tq7C4f728WZWZmpoqKinTJJZf09qhdsz4N72w89dRTXmlpqZeVleWNGTPG27Jli/VIXZJ0yuW5556zHq3H0uE0bM/zvD/96U/eiBEjvHA47A0bNsxbvny59UhdisVi3uzZs73S0lIvOzvb+853vuP94he/8OLxuPVoJ9mwYcMpn9OVlZWe5319KnZNTY1XWFjohcNhb+LEid6ePXtsh/ZOP/eBAwc6fa1u2LChz859Kn35NGw+jgEAYCIt/wYEAEh/FBAAwAQFBAAwQQEBAExQQAAAExQQAMAEBQQAMEEBAQBMUEAAABMUEADABAUEADBBAQEATPw/xF4qV84d/Q0AAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.imshow(agent.att_map[0].max(dim=0)[0].view(16,16))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 10.6"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CNN(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super(CNN, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(1,10,kernel_size=(4,4)) #A\n",
    "        self.conv2 = nn.Conv2d(10,16,kernel_size=(4,4))\n",
    "        self.conv3 = nn.Conv2d(16,24,kernel_size=(4,4))\n",
    "        self.conv4 = nn.Conv2d(24,32,kernel_size=(4,4))\n",
    "        self.maxpool1 = nn.MaxPool2d(kernel_size=(2,2)) #B\n",
    "        self.conv5 = nn.Conv2d(32,64,kernel_size=(4,4))\n",
    "        self.lin1 = nn.Linear(256,128)\n",
    "        self.out = nn.Linear(128,10) #C\n",
    "    def forward(self,x):\n",
    "        x = self.conv1(x)\n",
    "        x = nn.functional.relu(x)\n",
    "        x = self.conv2(x)\n",
    "        x = nn.functional.relu(x)\n",
    "        x = self.maxpool1(x)\n",
    "        x = self.conv3(x)\n",
    "        x = nn.functional.relu(x)\n",
    "        x = self.conv4(x)\n",
    "        x = nn.functional.relu(x)\n",
    "        x = self.conv5(x)\n",
    "        x = nn.functional.relu(x)\n",
    "        x = x.flatten(start_dim=1)\n",
    "        x = self.lin1(x)\n",
    "        x = nn.functional.relu(x)\n",
    "        x = self.out(x)\n",
    "        x = nn.functional.log_softmax(x,dim=1) #D\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([5, 3, 7, 7])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from einops import rearrange\n",
    "x = torch.randn(5,7,7,3)\n",
    "rearrange(x, \"batch h w c -> batch c h w\").shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 10.7"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MultiHeadRelationalModule(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super(MultiHeadRelationalModule, self).__init__()\n",
    "        self.conv1_ch = 16 \n",
    "        self.conv2_ch = 20\n",
    "        self.conv3_ch = 24\n",
    "        self.conv4_ch = 30\n",
    "        self.H = 28\n",
    "        self.W = 28\n",
    "        self.node_size = 64\n",
    "        self.lin_hid = 100\n",
    "        self.out_dim = 5\n",
    "        self.ch_in = 3\n",
    "        self.sp_coord_dim = 2\n",
    "        self.N = int(7**2)\n",
    "        self.n_heads = 3\n",
    "        \n",
    "        self.conv1 = nn.Conv2d(self.ch_in,self.conv1_ch,kernel_size=(1,1),padding=0) #A\n",
    "        self.conv2 = nn.Conv2d(self.conv1_ch,self.conv2_ch,kernel_size=(1,1),padding=0)\n",
    "        self.proj_shape = (self.conv2_ch+self.sp_coord_dim,self.n_heads * self.node_size)\n",
    "        self.k_proj = nn.Linear(*self.proj_shape)\n",
    "        self.q_proj = nn.Linear(*self.proj_shape)\n",
    "        self.v_proj = nn.Linear(*self.proj_shape)\n",
    "\n",
    "        self.k_lin = nn.Linear(self.node_size,self.N) #B\n",
    "        self.q_lin = nn.Linear(self.node_size,self.N)\n",
    "        self.a_lin = nn.Linear(self.N,self.N)\n",
    "        \n",
    "        self.node_shape = (self.n_heads, self.N,self.node_size)\n",
    "        self.k_norm = nn.LayerNorm(self.node_shape, elementwise_affine=True)\n",
    "        self.q_norm = nn.LayerNorm(self.node_shape, elementwise_affine=True)\n",
    "        self.v_norm = nn.LayerNorm(self.node_shape, elementwise_affine=True)\n",
    "        \n",
    "        self.linear1 = nn.Linear(self.n_heads * self.node_size, self.node_size)\n",
    "        self.norm1 = nn.LayerNorm([self.N,self.node_size], elementwise_affine=False)\n",
    "        self.linear2 = nn.Linear(self.node_size, self.out_dim)\n",
    "    \n",
    "    def forward(self,x):\n",
    "        N, Cin, H, W = x.shape\n",
    "        x = self.conv1(x) \n",
    "        x = torch.relu(x)\n",
    "        x = self.conv2(x) \n",
    "        x = torch.relu(x) \n",
    "        with torch.no_grad(): \n",
    "            self.conv_map = x.clone() #C\n",
    "        _,_,cH,cW = x.shape\n",
    "        xcoords = torch.arange(cW).repeat(cH,1).float() / cW\n",
    "        ycoords = torch.arange(cH).repeat(cW,1).transpose(1,0).float() / cH\n",
    "        spatial_coords = torch.stack([xcoords,ycoords],dim=0)\n",
    "        spatial_coords = spatial_coords.unsqueeze(dim=0)\n",
    "        spatial_coords = spatial_coords.repeat(N,1,1,1)\n",
    "        x = torch.cat([x,spatial_coords],dim=1)\n",
    "        x = x.permute(0,2,3,1)\n",
    "        x = x.flatten(1,2)\n",
    "        \n",
    "        K = rearrange(self.k_proj(x), \"b n (head d) -> b head n d\", head=self.n_heads)\n",
    "        K = self.k_norm(K) \n",
    "        \n",
    "        Q = rearrange(self.q_proj(x), \"b n (head d) -> b head n d\", head=self.n_heads)\n",
    "        Q = self.q_norm(Q) \n",
    "        \n",
    "        V = rearrange(self.v_proj(x), \"b n (head d) -> b head n d\", head=self.n_heads)\n",
    "        V = self.v_norm(V) \n",
    "        A = torch.nn.functional.elu(self.q_lin(Q) + self.k_lin(K)) #D\n",
    "        A = self.a_lin(A)\n",
    "        A = torch.nn.functional.softmax(A,dim=3) \n",
    "        with torch.no_grad():\n",
    "            self.att_map = A.clone() #E\n",
    "        E = torch.einsum('bhfc,bhcd->bhfd',A,V) #F\n",
    "        E = rearrange(E, 'b head n d -> b n (head d)')\n",
    "        E = self.linear1(E)\n",
    "        E = torch.relu(E)\n",
    "        E = self.norm1(E)\n",
    "        E = E.max(dim=1)[0]\n",
    "        y = self.linear2(E)\n",
    "        y = torch.nn.functional.elu(y)\n",
    "        return y"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 10.8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/don/git/DeepReinforcementLearningInAction/venv/lib/python3.10/site-packages/gym/envs/registration.py:307: DeprecationWarning: The package name gym_minigrid has been deprecated in favor of minigrid. Please uninstall gym_minigrid and install minigrid with `pip install minigrid`. Future releases will be maintained under the new package name minigrid.\n",
      "  fn()\n"
     ]
    }
   ],
   "source": [
    "import gymnasium as gym\n",
    "from gym_minigrid.minigrid import *\n",
    "from gym_minigrid.wrappers import FullyObsWrapper, ImgObsWrapper\n",
    "from skimage.transform import resize\n",
    "\n",
    "def prepare_state(x): #A\n",
    "    ns = torch.from_numpy(x).float().permute(2,0,1).unsqueeze(dim=0)#\n",
    "    maxv = ns.flatten().max()\n",
    "    ns = ns / maxv\n",
    "    return ns\n",
    "\n",
    "def get_minibatch(replay,size): #B\n",
    "    batch_ids = np.random.randint(0,len(replay),size)\n",
    "    batch = [replay[x] for x in batch_ids] #list of tuples\n",
    "    state_batch = torch.cat([s for (s,a,r,s2,d) in batch],)\n",
    "    action_batch = torch.Tensor([a for (s,a,r,s2,d) in batch]).long()\n",
    "    reward_batch = torch.Tensor([r for (s,a,r,s2,d) in batch])\n",
    "    state2_batch = torch.cat([s2 for (s,a,r,s2,d) in batch],dim=0)\n",
    "    done_batch = torch.Tensor([d for (s,a,r,s2,d) in batch])\n",
    "    return state_batch,action_batch,reward_batch,state2_batch, done_batch\n",
    "\n",
    "def get_qtarget_ddqn(qvals,r,df,done): #C\n",
    "    targets = r + (1-done) * df * qvals\n",
    "    return targets"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 10.9"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def lossfn(pred,targets,actions): #A\n",
    "    loss = torch.mean(torch.pow(\\\n",
    "                                targets.detach() -\\\n",
    "                                pred.gather(dim=1,index=actions.unsqueeze(dim=1)).squeeze()\\\n",
    "                                ,2),dim=0)\n",
    "    return loss\n",
    "  \n",
    "def update_replay(replay,exp,replay_size): #B\n",
    "    # Append the experience tuple to the replay buffer. Positively rewarded\n",
    "    # transitions are oversampled (50 copies) because success is rare early on.\n",
    "    # NOTE(review): replay_size is unused here -- the deque's maxlen already\n",
    "    # bounds the buffer; the parameter is kept for interface compatibility.\n",
    "    n_copies = 50 if exp[2] > 0 else 1\n",
    "    for _ in range(n_copies):\n",
    "        replay.append(exp)\n",
    "    return replay\n",
    "\n",
    "action_map = { #C maps the agent's 5 action indices onto MiniGrid env action codes\n",
    "    0:0, \n",
    "    1:1,\n",
    "    2:2,\n",
    "    3:3,\n",
    "    4:5,  # index 4 maps to env action 5, skipping env action 4 (presumably unused here -- confirm against MiniGrid's action list)\n",
    "}\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 10.10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/don/git/DeepReinforcementLearningInAction/venv/lib/python3.10/site-packages/gym/utils/passive_env_checker.py:233: DeprecationWarning: `np.bool8` is a deprecated alias for `np.bool_`.  (Deprecated NumPy 1.24)\n",
      "  if not isinstance(terminated, (bool, np.bool8)):\n"
     ]
    }
   ],
   "source": [
    "from collections import deque\n",
    "env = ImgObsWrapper(gym.make('MiniGrid-DoorKey-5x5-v0', render_mode=\"rgb_array\")) #A image-only observations\n",
    "state = prepare_state(env.reset()[0])\n",
    "GWagent = MultiHeadRelationalModule() #B online (learning) network\n",
    "Tnet = MultiHeadRelationalModule() #C target network for Double-DQN bootstrapping\n",
    "maxsteps = 400 #D cap episode length so unsuccessful episodes still end\n",
    "env.max_steps = maxsteps\n",
    "env.env.max_steps = maxsteps\n",
    "\n",
    "epochs = 50000\n",
    "replay_size = 9000\n",
    "batch_size = 50\n",
    "lr = 0.0005\n",
    "gamma = 0.99 # discount factor\n",
    "replay = deque(maxlen=replay_size) #E bounded experience-replay buffer\n",
    "opt = torch.optim.Adam(params=GWagent.parameters(),lr=lr)\n",
    "eps = 0.5 # epsilon-greedy exploration rate (held constant throughout training)\n",
    "update_freq = 100 # sync the target network every `update_freq` steps\n",
    "for i in range(epochs):\n",
    "    pred = GWagent(state)\n",
    "    action = int(torch.argmax(pred).detach().numpy())\n",
    "    if np.random.rand() < eps: #F explore with a uniformly random action\n",
    "        action = int(torch.randint(0,5,size=(1,)).squeeze())\n",
    "    action_d = action_map[action]\n",
    "    # Gymnasium returns separate terminated/truncated flags. Only `terminated`\n",
    "    # should zero the bootstrap term, but BOTH must trigger an episode reset;\n",
    "    # the original code discarded `truncated`, so episodes cut off by\n",
    "    # max_steps never reset.\n",
    "    state2, reward, terminated, truncated, info = env.step(action_d)\n",
    "    reward = -0.01 if reward == 0 else reward #G small step penalty when the env gives no reward\n",
    "    state2 = prepare_state(state2)\n",
    "    exp = (state,action,reward,state2,terminated) # store terminated only, so truncation still bootstraps\n",
    "\n",
    "    replay = update_replay(replay,exp,replay_size)\n",
    "    if terminated or truncated:\n",
    "        state = prepare_state(env.reset()[0])\n",
    "    else:\n",
    "        state = state2\n",
    "    if len(replay) > batch_size:\n",
    "        opt.zero_grad()\n",
    "        state_batch,action_batch,reward_batch,state2_batch,done_batch = get_minibatch(replay,batch_size)\n",
    "        q_pred = GWagent(state_batch).cpu()\n",
    "        # Double-DQN: the *online* net selects the best action for the NEXT\n",
    "        # states, and the *target* net evaluates it. The original code took\n",
    "        # astar = argmax(q_pred) over state_batch, i.e. the wrong states.\n",
    "        with torch.no_grad():\n",
    "            astar = torch.argmax(GWagent(state2_batch),dim=1)\n",
    "        qs = Tnet(state2_batch).gather(dim=1,index=astar.unsqueeze(dim=1)).squeeze()\n",
    "        targets = get_qtarget_ddqn(qs.detach(),reward_batch.detach(),gamma,done_batch)\n",
    "        loss = lossfn(q_pred,targets.detach(),action_batch)\n",
    "        loss.backward()\n",
    "        torch.nn.utils.clip_grad_norm_(GWagent.parameters(), max_norm=1.0) #H clip gradients to stabilize updates\n",
    "        opt.step()\n",
    "    if i % update_freq == 0: #I periodically copy online weights into the target net\n",
    "        Tnet.load_state_dict(GWagent.state_dict())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7fead10f6c80>"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZgAAAGdCAYAAAAv9mXmAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAWTUlEQVR4nO3db2xVhf348c+FyoVpWwEB6Siomw4RYSpCGLr5h2n4qdE9cMZgRphZoqlTJCaGJ8NliWVZZtwfgqKb+GAMNxPUmS8wxgRilMifkKBLVJSFTgTm4trSB1ek9/dgv3W/fhXmbfvhcuvrlZzEe3JOz+dE7bvnnNvbQrlcLgcADLAh1R4AgMFJYABIITAApBAYAFIIDAApBAaAFAIDQAqBASBF3ck+YHd3dxw4cCDq6+ujUCic7MMD0A/lcjk6Ozujqakphgw58TXKSQ/MgQMHorm5+WQfFoAB1NbWFhMmTDjhNic9MPX19RERcUX8n6iL00724QHoh4/jaLwc/9PzvfxETnpg/n1brC5Oi7qCwADUlP/36ZWf5RGHh/wApBAYAFIIDAApBAaAFAIDQAqBASCFwACQQmAASCEwAKQQGABSCAwAKQQGgBQCA0AKgQEghcAAkEJgAEghMACkEBgAUggMACn6FJjly5fHOeecE8OHD49Zs2bFa6+9NtBzAVDjKg7MM888E4sXL46lS5fGrl27Yvr06XH99dfH4cOHM+YDoEZVHJhHHnkkvve978XChQtjypQp8dhjj8UXvvCF+PWvf50xHwA1qqLAfPTRR7Fz586YO3fuf77AkCExd+7cePXVVz91n1KpFB0dHb0WAAa/igLzwQcfxLFjx2LcuHG91o8bNy4OHjz4qfu0trZGY2Njz9Lc3Nz3aQGoGenvIluyZEm0t7f3LG1tbdmHBOAUUFfJxmeddVYMHTo0Dh061Gv9oUOH4uyzz/7UfYrFYhSLxb5PCEBNqugKZtiwYXHZZZfFpk2betZ1d3fHpk2bYvbs2QM+HAC1q6IrmIiIxYsXx4IFC2LGjBkxc+bMePTRR6OrqysWLlyYMR8ANariwNx2223x97//PX7wgx/EwYMH46tf/WqsX7/+Ew/+Afh8K5TL5fLJPGBHR0c0NjbGVXFz1BVOO5mHBqCfPi4fjc3xfLS3t0dDQ8MJt/VZZACkEBgAUggMACkEBoAUAgNACoEBIIXAAJBCYABIITAApBAYAFIIDAApBAaAFAIDQAqBASCFwACQQmAASCEwAKQQGABSCAwAKQQGgBQCA0AKgQEghcAAkEJgAEghMACkEBgAUggMACkEBoAUAgNACoEBIIXAAJBCYABIITAApBAYAFIIDAApBAaAFAIDQAqBASCFwACQQmAASCEwAKQQGABSVByYrVu3xk033RRNTU1RKBTiueeeSxgLgFpXcWC6urpi+vTpsXz58ox5ABgk6irdYd68eTFv3ryMWQAYRCoOTKVKpVKUSqWe1x0dHdmHBOAUkP6Qv7W1NRobG3uW5ubm7EMCcApID8ySJUuivb29Z2lra8s+JACngPRbZMViMYrFYvZhADjF+D0YAFJUfAVz5MiR2Lt3b8/rffv2xe7du2PUqFExceLEAR0OgNpVcWB27NgRV199dc/rxYsXR0TEggULYtWqVQM2GAC1reLAXHXVVVEulzNmAWAQ8QwGgBQCA0AKgQEghcAAkEJgAEghMACkEBgAUggMACkEBoAUAgNACoEBIIXAAJBCYABIITAApBAYAFIIDAApBAaAFAIDQAqBASCFwACQQmAASCEwAKQQGABSCAwAKQQGgBQCA0AKgQEghcAAkEJgAEghMACkEBgAUggMACkEBoAUAgNACoEBIIXAAJBCYABIITAApBAYAFIIDAApBAaAFAIDQIqKAtPa2hqXX3551NfXx9ixY+OWW26JN998M2s2AGpYRYHZsmVLtLS0xLZt22Ljxo1x9OjRuO6666Kr
qytrPgBqVF0lG69fv77X61WrVsXYsWNj586d8fWvf31ABwOgtlUUmP+tvb09IiJGjRp13G1KpVKUSqWe1x0dHf05JAA1os8P+bu7u2PRokUxZ86cmDp16nG3a21tjcbGxp6lubm5r4cEoIb0OTAtLS3x+uuvx5o1a0643ZIlS6K9vb1naWtr6+shAaghfbpFds8998SLL74YW7dujQkTJpxw22KxGMVisU/DAVC7KgpMuVyO73//+7F27drYvHlznHvuuVlzAVDjKgpMS0tLrF69Op5//vmor6+PgwcPRkREY2NjjBgxImVAAGpTRc9gVqxYEe3t7XHVVVfF+PHje5Znnnkmaz4AalTFt8gA4LPwWWQApBAYAFIIDAApBAaAFAIDQAqBASCFwACQQmAASCEwAKQQGABSCAwAKQQGgBQCA0AKgQEghcAAkEJgAEghMACkEBgAUlT0J5MB6INCodoTDKBCRPmzbekKBoAUAgNACoEBIIXAAJBCYABIITAApBAYAFIIDAApBAaAFAIDQAqBASCFwACQQmAASCEwAKQQGABSCAwAKQQGgBQCA0AKgQEghcAAkEJgAEghMACkqCgwK1asiGnTpkVDQ0M0NDTE7NmzY926dVmzAVDDKgrMhAkTYtmyZbFz587YsWNHXHPNNXHzzTfHG2+8kTUfADWqUC6Xy/35AqNGjYqf/OQnceedd36m7Ts6OqKxsTGuipujrnBafw4NUBsKhWpPMGA+Lh+NzeXnor29PRoaGk64bV1fD3Ls2LH4/e9/H11dXTF79uzjblcqlaJUKvW87ujo6OshAaghFT/k37NnT5xxxhlRLBbjrrvuirVr18aUKVOOu31ra2s0Njb2LM3Nzf0aGIDaUPEtso8++ij2798f7e3t8eyzz8aTTz4ZW7ZsOW5kPu0Kprm52S0y4PPjc3qLrN/PYObOnRtf+tKX4vHHH/9M23sGA3zufE4D0+/fg+nu7u51hQIAERU+5F+yZEnMmzcvJk6cGJ2dnbF69erYvHlzbNiwIWs+AGpURYE5fPhwfOc734n3338/GhsbY9q0abFhw4b45je/mTUfADWqosD86le/ypoDgEHGZ5EBkEJgAEghMACkEBgAUggMACkEBoAUAgNACoEBIIXAAJBCYABIITAApBAYAFIIDAApBAaAFAIDQAqBASCFwACQQmAASCEwAKSoq/YAAINeuVztCQZOBefiCgaAFAIDQAqBASCFwACQQmAASCEwAKQQGABSCAwAKQQGgBQCA0AKgQEghcAAkEJgAEghMACkEBgAUggMACkEBoAUAgNACoEBIIXAAJBCYABIITAApOhXYJYtWxaFQiEWLVo0QOMAMFj0OTDbt2+Pxx9/PKZNmzaQ8wAwSPQpMEeOHIn58+fHE088ESNHjhzomQAYBPoUmJaWlrjhhhti7ty5/3XbUqkUHR0dvRYABr+6SndYs2ZN7Nq1K7Zv3/6Ztm9tbY0f/vCHFQ8GQG2r6Aqmra0t7rvvvvjNb34Tw4cP/0z7LFmyJNrb23uWtra2Pg0KQG2p6Apm586dcfjw4bj00kt71h07diy2bt0av/zlL6NUKsXQoUN77VMsFqNYLA7MtADUjIoCc+2118aePXt6rVu4cGFMnjw5HnzwwU/EBYDPr4oCU19fH1OnTu217vTTT4/Ro0d/Yj0An29+kx+AFBW/i+x/27x58wCMAcBg4woGgBQCA0AKgQEghcAAkEJgAEghMACkEBgAUggMACkEBoAUAgNACoEBIIXAAJBCYABIITAApBAYAFIIDAApBAaAFAIDQAqBASBFXbUOvPatPdFQP3j6dn3TV6s9Av/FkPr6ao8w4Lo7O6s9wsArFKo9wcArl6s9QVUMnu/wAJxSBAaAFAIDQAqBASCFwACQQmAASCEwAKQQGABSCAwAKQQGgBQCA0AKgQEghcAAkEJgAEghMACkEBgAUggMACkEBoAUAgNACoEBIIXAAJBCYABIUVFg
HnrooSgUCr2WyZMnZ80GQA2rq3SHiy66KP70pz/95wvUVfwlAPgcqLgOdXV1cfbZZ2fMAsAgUvEzmLfffjuamprivPPOi/nz58f+/ftPuH2pVIqOjo5eCwCDX0WBmTVrVqxatSrWr18fK1asiH379sWVV14ZnZ2dx92ntbU1Ghsbe5bm5uZ+Dw3Aqa9QLpfLfd35n//8Z0yaNCkeeeSRuPPOOz91m1KpFKVSqed1R0dHNDc3x4dvnRcN9YPnTWzXN3212iPwXwypr6/2CAOu+wQ/3NWsQqHaEwy8vn+bPeV8XD4am+P5aG9vj4aGhhNu268n9GeeeWZccMEFsXfv3uNuUywWo1gs9ucwANSgfl1CHDlyJN55550YP378QM0DwCBRUWAeeOCB2LJlS/z1r3+NV155Jb71rW/F0KFD4/bbb8+aD4AaVdEtsr/97W9x++23xz/+8Y8YM2ZMXHHFFbFt27YYM2ZM1nwA1KiKArNmzZqsOQAYZAbP27gAOKUIDAApBAaAFAIDQAqBASCFwACQQmAASCEwAKQQGABSCAwAKQQGgBQCA0AKgQEghcAAkEJgAEghMACkEBgAUggMACkEBoAUddU68LcuuDjqCqdV6/B8DnV3dlZ7BD6LcrnaEwy8QqHaEwygQsRn/FfkCgaAFAIDQAqBASCFwACQQmAASCEwAKQQGABSCAwAKQQGgBQCA0AKgQEghcAAkEJgAEghMACkEBgAUggMACkEBoAUAgNACoEBIIXAAJBCYABIITAApKg4MO+9917ccccdMXr06BgxYkRcfPHFsWPHjozZAKhhdZVs/OGHH8acOXPi6quvjnXr1sWYMWPi7bffjpEjR2bNB0CNqigwP/7xj6O5uTmeeuqpnnXnnnvugA8FQO2r6BbZCy+8EDNmzIhbb701xo4dG5dcckk88cQTJ9ynVCpFR0dHrwWAwa+iwLz77ruxYsWKOP/882PDhg1x9913x7333htPP/30cfdpbW2NxsbGnqW5ubnfQwNw6iuUy+XyZ9142LBhMWPGjHjllVd61t17772xffv2ePXVVz91n1KpFKVSqed1R0dHNDc3x1Vxc9QVTuvH6AA1olCo9gQD5uPy0dhcfi7a29ujoaHhhNtWdAUzfvz4mDJlSq91F154Yezfv/+4+xSLxWhoaOi1ADD4VRSYOXPmxJtvvtlr3VtvvRWTJk0a0KEAqH0VBeb++++Pbdu2xcMPPxx79+6N1atXx8qVK6OlpSVrPgBqVEWBufzyy2Pt2rXx29/+NqZOnRo/+tGP4tFHH4358+dnzQdAjaro92AiIm688ca48cYbM2YBYBDxWWQApBAYAFIIDAApBAaAFAIDQAqBASCFwACQQmAASCEwAKQQGABSCAwAKQQGgBQCA0AKgQEghcAAkEJgAEghMACkEBgAUlT8J5P7q1wuR0TEx3E0onyyjw5QDYVqDzBgPi4fjYj/fC8/kZMemM7OzoiIeDn+52QfGqA6BuEP052dndHY2HjCbQrlz5KhAdTd3R0HDhyI+vr6KBTyqt7R0RHNzc3R1tYWDQ0Nacc5mZzTqW+wnU+Ec6oVJ+ucyuVydHZ2RlNTUwwZcuKnLCf9CmbIkCExYcKEk3a8hoaGQfMf0L85p1PfYDufCOdUK07GOf23K5d/85AfgBQCA0CKQRuYYrEYS5cujWKxWO1RBoxzOvUNtvOJcE614lQ8p5P+kB+Az4dBewUDQHUJDAApBAaAFAIDQIpBGZjly5fHOeecE8OHD49Zs2bFa6+9Vu2R+mXr1q1x0003RVNTUxQKhXjuueeqPVK/tLa2xuWXXx719fUxduzYuOWWW+LNN9+s9lj9smLFipg2bVrPL7nNnj071q1bV+2xBtSyZcuiUCjEokWLqj1Knz300ENRKBR6LZMnT672WP3y3nvvxR133BGjR4+OESNGxMUXXxw7duyo9lgRMQgD88wzz8TixYtj6dKlsWvXrpg+fXpcf/31cfjw
4WqP1mddXV0xffr0WL58ebVHGRBbtmyJlpaW2LZtW2zcuDGOHj0a1113XXR1dVV7tD6bMGFCLFu2LHbu3Bk7duyIa665Jm6++eZ44403qj3agNi+fXs8/vjjMW3atGqP0m8XXXRRvP/++z3Lyy+/XO2R+uzDDz+MOXPmxGmnnRbr1q2Lv/zlL/HTn/40Ro4cWe3R/qU8yMycObPc0tLS8/rYsWPlpqamcmtraxWnGjgRUV67dm21xxhQhw8fLkdEecuWLdUeZUCNHDmy/OSTT1Z7jH7r7Owsn3/++eWNGzeWv/GNb5Tvu+++ao/UZ0uXLi1Pnz692mMMmAcffLB8xRVXVHuM4xpUVzAfffRR7Ny5M+bOnduzbsiQITF37tx49dVXqzgZJ9Le3h4REaNGjaryJAPj2LFjsWbNmujq6orZs2dXe5x+a2lpiRtuuKHX/1e17O23346mpqY477zzYv78+bF///5qj9RnL7zwQsyYMSNuvfXWGDt2bFxyySXxxBNPVHusHoMqMB988EEcO3Ysxo0b12v9uHHj4uDBg1WaihPp7u6ORYsWxZw5c2Lq1KnVHqdf9uzZE2eccUYUi8W46667Yu3atTFlypRqj9Uva9asiV27dkVra2u1RxkQs2bNilWrVsX69etjxYoVsW/fvrjyyit7/oxIrXn33XdjxYoVcf7558eGDRvi7rvvjnvvvTeefvrpao8WEVX4NGX4/7W0tMTrr79e0/fB/+0rX/lK7N69O9rb2+PZZ5+NBQsWxJYtW2o2Mm1tbXHffffFxo0bY/jw4dUeZ0DMmzev55+nTZsWs2bNikmTJsXvfve7uPPOO6s4Wd90d3fHjBkz4uGHH46IiEsuuSRef/31eOyxx2LBggVVnm6QXcGcddZZMXTo0Dh06FCv9YcOHYqzzz67SlNxPPfcc0+8+OKL8dJLL53UP+GQZdiwYfHlL385LrvssmhtbY3p06fHz372s2qP1Wc7d+6Mw4cPx6WXXhp1dXVRV1cXW7ZsiZ///OdRV1cXx44dq/aI/XbmmWfGBRdcEHv37q32KH0yfvz4T/wAc+GFF54yt/0GVWCGDRsWl112WWzatKlnXXd3d2zatGlQ3AsfLMrlctxzzz2xdu3a+POf/xznnntutUdK0d3dHaVSqdpj9Nm1114be/bsid27d/csM2bMiPnz58fu3btj6NCh1R6x344cORLvvPNOjB8/vtqj9MmcOXM+8Rb/t956KyZNmlSliXobdLfIFi9eHAsWLIgZM2bEzJkz49FHH42urq5YuHBhtUfrsyNHjvT6CWvfvn2xe/fuGDVqVEycOLGKk/VNS0tLrF69Op5//vmor6/veT7W2NgYI0aMqPJ0fbNkyZKYN29eTJw4MTo7O2P16tWxefPm2LBhQ7VH67P6+vpPPBc7/fTTY/To0TX7vOyBBx6Im266KSZNmhQHDhyIpUuXxtChQ+P222+v9mh9cv/998fXvva1ePjhh+Pb3/52vPbaa7Fy5cpYuXJltUf7l2q/jS3DL37xi/LEiRPLw4YNK8+cObO8bdu2ao/ULy+99FI5/vVXvXstCxYsqPZoffJp5xIR5aeeeqrao/XZd7/73fKkSZPKw4YNK48ZM6Z87bXXlv/4xz9We6wBV+tvU77tttvK48ePLw8bNqz8xS9+sXzbbbeV9+7dW+2x+uUPf/hDeerUqeVisViePHlyeeXKldUeqYeP6wcgxaB6BgPAqUNgAEghMACkEBgAUggMACkEBoAUAgNACoEBIIXAAJBCYABIITAApBAYAFL8X2ryMlhxCuaaAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Visualize what one attention head attends to for a single grid node.\n",
    "state_ = env.reset()\n",
    "state = prepare_state(state_[0])\n",
    "GWagent(state) # forward pass; presumably populates GWagent.att_map as a side effect -- confirm in the model class\n",
    "# NOTE(review): the three imshow calls below draw onto the same axes, so only\n",
    "# the final attention-map image is visible in this cell's output.\n",
    "plt.imshow(env.render())\n",
    "plt.imshow(state[0].permute(1,2,0).detach().numpy())\n",
    "head, node = 2, 26 # attention head index and node index (7x7 grid -> 49 nodes)\n",
    "plt.imshow(GWagent.att_map[0][head][node].view(7,7))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
